# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import argparse
import datetime
import json
import os
import time
from urllib.parse import urlparse

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from config import get_config
from data.build import build_loader
from logger import create_logger
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from timm.loss import SoftTargetCrossEntropy
from timm.utils import AverageMeter
from utils import get_grad_norm
from utils import load_checkpoint
from utils import reduce_tensor
from utils import save_checkpoint

import volcengine_ml_platform
from samples.models.swin_transformer_pytorch.build import build_model
from volcengine_ml_platform import constant
from volcengine_ml_platform.io import tos

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

# Initialise the Volcengine ML platform SDK at import time (presumably needed
# before the ``constant`` / ``tos`` helpers below are usable — TODO confirm).
volcengine_ml_platform.init()
# Read-only public-examples bucket: the train/val manifests and the pretrained
# Swin-T weights are downloaded from here in ``main``.
BUCKET = constant.get_public_examples_readonly_bucket()
# Bucket the resume checkpoint is downloaded from in ``main`` and which is
# passed to ``save_checkpoint``; presumably writable — NOTE(review): verify.
USER_BUCKET = "mlplatform-public-examples-cn-beijing"


def parse_option():
    """Parse command-line arguments and build the experiment config.

    Unrecognised command-line arguments are ignored (``parse_known_args``).

    Returns:
        A ``(args, config)`` tuple: the parsed ``argparse.Namespace`` and the
        config object produced by ``get_config(args)``.
    """
    parser = argparse.ArgumentParser(
        "Swin Transformer training and evaluation scripts",
        add_help=False,
    )
    parser.add_argument(
        "--cfg",
        type=str,
        required=True,
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs="+",
    )

    # easy config modification
    parser.add_argument(
        "--batch-size",
        type=int,
        help="batch size for single GPU",
    )
    parser.add_argument("--data-path", type=str, help="path to dataset")
    parser.add_argument(
        "--cache-mode",
        type=str,
        default="part",
        choices=["no", "full", "part"],
        help="no: no cache, "
        "full: cache all data, "
        "part: sharding the dataset into nonoverlapping pieces and only cache one piece",
    )
    parser.add_argument(
        "--zip",
        action="store_true",
        help="use zipped dataset instead of folder dataset",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="resume from checkpoint",
    )
    parser.add_argument(
        "--accumulation-steps",
        type=int,
        help="gradient accumulation steps",
    )
    parser.add_argument(
        "--use-checkpoint",
        action="store_true",
        help="whether to use gradient checkpointing to save memory",
    )
    parser.add_argument("--device", help="use cpu or gpu to train")
    parser.add_argument(
        "--amp-opt-level",
        type=str,
        default="O1",
        choices=["O0", "O1", "O2"],
        help="mixed precision opt level, if O0, no amp is used",
    )
    parser.add_argument(
        "--output",
        default="output",
        type=str,
        metavar="PATH",
        help="root of output folder, the full path is <output>/<model_name>/<tag> (default: output)",
    )
    parser.add_argument("--tag", help="tag of experiment")
    parser.add_argument(
        "--eval",
        action="store_true",
        help="Perform evaluation only",
    )
    parser.add_argument(
        "--throughput",
        action="store_true",
        help="Test throughput only",
    )
    parser.add_argument(
        "--load_pretrained",
        action="store_true",
        help="load pretrained model",
    )

    # distributed training
    # torchrun only exports LOCAL_RANK in the environment, and
    # torch.distributed.launch in PyTorch >= 2.0 passes ``--local-rank``
    # (hyphenated). Accept both spellings and fall back to the environment
    # variable; the option stays mandatory only when neither is available.
    env_local_rank = os.environ.get("LOCAL_RANK")
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        dest="local_rank",
        type=int,
        required=env_local_rank is None,
        default=None if env_local_rank is None else int(env_local_rank),
        help="local rank for DistributedDataParallel",
    )

    args, _ = parser.parse_known_args()

    config = get_config(args)

    return args, config


def get_manifest(url, file_path):
    """Download a JSON manifest from TOS and return its decoded contents.

    The bucket is the first dot-separated label of the URL's host and the
    object key is the URL path with its leading character (``/``) removed.

    Args:
        url: Object URL of the manifest, e.g. ``s3://bucket/dir/file.json``.
        file_path: Local destination the object is downloaded to.

    Returns:
        The parsed JSON document.
    """
    parsed = urlparse(url)
    bucket_name, _, _ = parsed.netloc.partition(".")
    object_key = parsed.path[1:]

    tos.TOSClient().download_file(bucket_name, object_key, file_path)

    with open(file_path, encoding="utf-8") as manifest_file:
        return json.load(manifest_file)


def _download_manifests():
    """Fetch the train/val manifests from BUCKET; returns ``[train, val]``."""
    train_url = f"s3://{BUCKET}/flower-classification/photos/train_manifest.json"
    val_url = f"s3://{BUCKET}/flower-classification/photos/val_manifest.json"
    train_manifest_info = get_manifest(train_url, "./train_manifest.json")
    val_manifest_info = get_manifest(val_url, "./val_manifest.json")
    return [train_manifest_info, val_manifest_info]


def _wrap_ddp(config, model):
    """Wrap *model* in DDP, pinning it to ``config.LOCAL_RANK`` on CUDA."""
    if config.DEVICE == "cuda":
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[config.LOCAL_RANK],
            broadcast_buffers=False,
        )
    return torch.nn.parallel.DistributedDataParallel(
        model,
        broadcast_buffers=False,
    )


def _download_initial_weights(config):
    """Download the resume checkpoint (preferred) or pretrained weights, if enabled.

    Files are written to fixed local paths; presumably the config points
    ``load_checkpoint`` at them — NOTE(review): confirm in config/utils.
    """
    client = tos.TOSClient()
    if config.MODEL.LOAD_CHECKPOINT:
        client.download_file(
            file_path="./ckpt.pth",
            bucket=USER_BUCKET,
            key="flower-classification/checkpoints/pytorch_ckpt.pth",
        )
    elif config.MODEL.LOAD_PRETRAINED:
        client.download_file(
            file_path="./swin_tiny_patch4_window7_224.pth",
            bucket=BUCKET,
            key="flower-classification/swin_tiny_patch4_window7_224.pth",
        )


def main(config):
    """Build data, model and optimizer; evaluate once, then train unless in eval mode.

    After every epoch, rank 0 saves a checkpoint every ``config.SAVE_FREQ``
    epochs and after the last epoch, and the model is re-validated.

    Relies on the module-level ``logger`` and ``device`` globals that the
    ``__main__`` section sets up before calling this function.
    """
    manifest_info = _download_manifests()
    _, dataset_val, data_loader_train, data_loader_val = build_loader(
        config,
        manifest_info,
    )

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.to(device)
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=config.AMP_OPT_LEVEL,
        )
    model = _wrap_ddp(config, model)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, "flops"):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    # ``accuracy``/``validate`` treat targets as per-class distributions
    # (argmax over the last dim), hence the soft-target loss.
    criterion = SoftTargetCrossEntropy()

    _download_initial_weights(config)
    # The value returned by load_checkpoint seeds the best-accuracy tracker.
    max_accuracy = load_checkpoint(
        config,
        model_without_ddp,
        optimizer,
        lr_scheduler,
        logger,
    )
    acc1, _, _ = validate(config, data_loader_val, model)
    logger.info(
        f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%",
    )
    if config.EVAL_MODE:
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # Reshuffle the distributed sampler differently each epoch.
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(
            config,
            model,
            criterion,
            data_loader_train,
            optimizer,
            epoch,
            lr_scheduler,
        )
        is_last_epoch = epoch == (config.TRAIN.EPOCHS - 1)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or is_last_epoch):
            save_checkpoint(
                config,
                epoch,
                model_without_ddp,
                max_accuracy,
                optimizer,
                lr_scheduler,
                logger,
                USER_BUCKET,
            )

        acc1, _, _ = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%",
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f"Max accuracy: {max_accuracy:.2f}%")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Training time {total_time_str}")


def _backward_and_grad_norm(config, model, optimizer, loss):
    """Back-propagate *loss* and return the (pre-clipping) gradient norm.

    Uses apex AMP loss scaling unless ``config.AMP_OPT_LEVEL == "O0"``. When
    ``config.TRAIN.CLIP_GRAD`` is truthy, gradients are clipped in place to that
    max norm via ``clip_grad_norm_``; otherwise ``get_grad_norm`` is reported.
    """
    if config.AMP_OPT_LEVEL != "O0":
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        params = amp.master_params(optimizer)
    else:
        loss.backward()
        params = model.parameters()
    if config.TRAIN.CLIP_GRAD:
        return torch.nn.utils.clip_grad_norm_(params, config.TRAIN.CLIP_GRAD)
    return get_grad_norm(params)


def train_one_epoch(
    config,
    model,
    criterion,
    data_loader,
    optimizer,
    epoch,
    lr_scheduler,
):
    """Run one training epoch over *data_loader*.

    Args:
        config: Experiment config (TRAIN.*, AMP_OPT_LEVEL, DEVICE, PRINT_FREQ).
        model: The (DDP-wrapped) model being trained.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        data_loader: Yields ``(samples, targets)`` batches.
        optimizer: Optimizer stepped after each (accumulated) update.
        epoch: Current epoch index, used for logging and scheduler indexing.
        lr_scheduler: Its ``step_update`` receives ``epoch * num_steps + idx``.

    With ``config.TRAIN.ACCUMULATION_STEPS > 1`` the loss is divided by the
    accumulation count and the optimizer only steps every that many batches.

    Relies on the module-level ``device`` and ``logger`` globals.
    """
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    accumulation_steps = config.TRAIN.ACCUMULATION_STEPS
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        outputs = model(samples)
        loss = criterion(outputs, targets)

        if accumulation_steps > 1:
            # Average the loss over the micro-batches of one optimizer step.
            loss = loss / accumulation_steps
            grad_norm = _backward_and_grad_norm(config, model, optimizer, loss)
            if (idx + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.zero_grad()
            grad_norm = _backward_and_grad_norm(config, model, optimizer, loss)
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        if config.DEVICE == "cuda":
            # Make the timing below reflect finished GPU work.
            torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        # Store a plain float so the meter never holds a device tensor.
        norm_meter.update(float(grad_norm))
        batch_time.update(time.time() - end)
        examples_per_second = samples.shape[0] / (time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]["lr"]
            etas = batch_time.avg * (num_steps - idx)
            message = (
                f"Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t"
                f"eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t"
                f"time {batch_time.val:.4f} ({batch_time.avg:.4f})\t"
                f"examples_per_second {examples_per_second:.4f}\t"
                f"loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t"
                f"grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t"
            )
            if config.DEVICE == "cuda":
                memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
                message += f"mem {memory_used:.0f}MB"
            logger.info(message)

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}",
    )


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in *topk*.

    Args:
        output: Logits of shape ``(batch, num_classes)``.
        target: Either per-class label distributions ``(batch, num_classes)``
            (one-hot / soft targets, reduced with argmax) or hard class
            indices of shape ``(batch,)``.
        topk: The k values to report. A k larger than ``num_classes`` is
            clamped, so it reports 100% instead of raising in ``topk``.

    Returns:
        A list of scalar tensors, one accuracy percentage per entry of *topk*.
    """
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    if target.dim() > 1:
        # Soft / one-hot targets: the predicted class is the argmax.
        target = target.argmax(dim=-1)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [
        correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size
        for k in topk
    ]


@torch.no_grad()
def validate(config, data_loader, model):
    """Evaluate *model* over *data_loader* without gradients.

    Per-batch loss and top-1/top-5 accuracy are averaged across ranks with
    ``reduce_tensor`` before being accumulated.

    Returns:
        ``(acc1_avg, acc5_avg, loss_avg)`` over all batches.

    Relies on the module-level ``device`` and ``logger`` globals.
    """
    criterion = SoftTargetCrossEntropy()
    model.eval()

    batch_time, loss_meter, acc1_meter, acc5_meter = (AverageMeter() for _ in range(4))

    tick = time.time()
    for step, (images, target) in enumerate(data_loader):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = model(images)

        batch_loss = criterion(output, target)
        top1, top5 = accuracy(output, target, topk=(1, 5))

        # Collective reductions, issued in the same order on every rank.
        top1 = reduce_tensor(top1)
        top5 = reduce_tensor(top5)
        batch_loss = reduce_tensor(batch_loss)

        count = target.size(0)
        loss_meter.update(batch_loss.item(), count)
        acc1_meter.update(top1.item(), count)
        acc5_meter.update(top5.item(), count)

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % config.PRINT_FREQ != 0:
            continue

        line = (
            f"Test: [{step}/{len(data_loader)}]\t"
            f"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
            f"Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t"
            f"Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t"
            f"Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t"
        )
        if config.DEVICE == "cuda":
            mem_mb = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            line += f"Mem {mem_mb:.0f}MB"
        logger.info(line)

    logger.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg


@torch.no_grad()
def throughput(data_loader, model, logger, *, warmup_iters=50, timed_iters=30):
    """Log the inference throughput of *model* on the first batch of *data_loader*.

    Runs ``warmup_iters`` untimed forward passes, then times ``timed_iters``
    passes and logs ``timed_iters * batch_size / elapsed_seconds``. Only the
    first batch is used; an empty loader logs nothing.

    Relies on the module-level ``device`` global for input placement.
    """
    model.eval()

    for images, _labels in data_loader:
        images = images.to(device, non_blocking=True)
        batch_size = images.shape[0]
        # Decide on synchronization from the tensor itself rather than the
        # module-level ``config``, which only exists when run as a script.
        is_cuda = images.is_cuda
        for _ in range(warmup_iters):
            model(images)
        if is_cuda:
            torch.cuda.synchronize()
        logger.info(f"throughput averaged with {timed_iters} times")
        tic1 = time.time()
        for _ in range(timed_iters):
            model(images)
        if is_cuda:
            torch.cuda.synchronize()
        tic2 = time.time()
        logger.info(
            f"batch_size {batch_size} throughput {timed_iters * batch_size / (tic2 - tic1)}",
        )
        return


if __name__ == "__main__":
    _, config = parse_option()

    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if config.AMP_OPT_LEVEL != "O0" and amp is None:
        raise RuntimeError("amp not installed!")

    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        # -1 is init_process_group's "not specified" default for these.
        rank = -1
        world_size = -1

    # ``device`` and ``logger`` are module-level globals read by main(),
    # train_one_epoch(), validate() and throughput().
    device = config.DEVICE
    if device == "cuda":
        torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    torch.distributed.barrier()

    seed = config.SEED + dist.get_rank()

    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    def _linear_scaled_lr(lr):
        """Scale *lr* by global batch size / 512 (and by accumulation steps)."""
        # Linear scaling by total batch size; may not be optimal.
        scaled = lr * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
        # Gradient accumulation enlarges the effective batch as well.
        if config.TRAIN.ACCUMULATION_STEPS > 1:
            scaled = scaled * config.TRAIN.ACCUMULATION_STEPS
        return scaled

    config.defrost()
    config.TRAIN.BASE_LR = _linear_scaled_lr(config.TRAIN.BASE_LR)
    config.TRAIN.WARMUP_LR = _linear_scaled_lr(config.TRAIN.WARMUP_LR)
    config.TRAIN.MIN_LR = _linear_scaled_lr(config.TRAIN.MIN_LR)
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(
        output_dir=config.OUTPUT,
        dist_rank=dist.get_rank(),
        name=f"{config.MODEL.NAME}",
    )

    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w", encoding="utf-8") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())

    main(config)
